"""
Phase 6: Variance Decomposition
=================================
Baron-Kenny mediation analysis + Shapley-Owen R² decomposition.

Table 9: Mediation paths (Z → S-I gap → CA, Z → gross_ifi → income_bal, Z → debt_assets → income_bal)
Table 10: Shapley-Owen R² decomposition across {Z, S-I, gross_positions} groups
"""

import math
import sys
import warnings
from itertools import combinations
from pathlib import Path

import numpy as np
import pandas as pd
# Suppress every warning category for the rest of the run.
warnings.filterwarnings('ignore')

# PROJECT_DIR is the directory two levels above this file; ROOT_DIR is its parent.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral/src" directory first on the import path.
# This has to run before the `model` import on the next line.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel location and output directory for the markdown tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Created at import time (including parents); a no-op when it already exists.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Regressors whose coefficients are tracked through the mediation steps
# (presumably the demographic "Z" factors — confirm against the data build).
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']
# Candidate control regressors; main() keeps only those present in the panel
# with more than 200 non-null observations.
EBA_CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']


def stars(p):
    """Return the significance marker for p-value *p*.

    ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.1 and
    an empty string otherwise (including when no comparison holds, e.g. NaN).
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format a coefficient and its standard error for a table cell.

    Returns a 2-tuple: the coefficient with four decimals followed by its
    significance stars, and the standard error with four decimals wrapped
    in parentheses.
    """
    coef_cell = f"{val:.4f}{stars(p)}"
    se_cell = f"({se:.4f})"
    return coef_cell, se_cell


def run_gls(df, y_var, x_vars, label, quiet=False):
    """Fit a PanelGLS regression of ``y_var`` on ``x_vars`` and collect results.

    Rows with a missing value in ``y_var``, any of ``x_vars``, ``iso3`` or
    ``year`` are dropped before fitting.

    Args:
        df: Panel DataFrame containing the columns above.
        y_var: Name of the dependent-variable column.
        x_vars: List of regressor column names.
        label: Model label stored in the result and used in console output.
        quiet: When true, nothing is printed.

    Returns:
        ``None`` when fewer than 50 complete rows remain. Otherwise a dict
        with ``model``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared``, ``rho`` and, for every regressor ``x``, the keys
        ``x_coef``, ``x_se`` and ``x_p``.
    """
    needed = [y_var] + x_vars + ['iso3', 'year']
    sample = df[needed].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        # Too few complete observations to fit; report unless silenced.
        if not quiet:
            print(f"  SKIP {label}: only {n_rows} obs")
        return None

    model = PanelGLS()
    model.fit(
        sample[y_var].values,
        sample[x_vars].values,
        sample['iso3'].values,
        sample['year'].values,
    )

    result = {
        'model': label,
        'dep_var': y_var,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
        'r_squared': model.r_squared,
        'rho': model.rho,
    }
    # Estimates are read positionally, in the same order as x_vars.
    for pos, regressor in enumerate(x_vars):
        result.update({
            f'{regressor}_coef': model.beta[pos],
            f'{regressor}_se': model.se[pos],
            f'{regressor}_p': model.pvalues[pos],
        })

    if quiet:
        return result

    print(f"\n  {label} (N={model.n_obs}, R²={model.r_squared:.4f})")
    for pos, regressor in enumerate(x_vars):
        # Only the Z regressors are echoed; controls stay out of the console log.
        if regressor in DEMO_VARS or 'Z_' in regressor:
            sig = stars(model.pvalues[pos])
            print(f"    {regressor:30s} {model.beta[pos]:10.4f} ({model.se[pos]:.4f}) {sig}")

    return result


def compute_attenuation(baseline, mediated, var='Z_1'):
    """Percentage attenuation of the ``var`` coefficient between two models.

    Args:
        baseline: Result dict of the model without the mediator.
        mediated: Result dict of the model that includes the mediator.
        var: Regressor whose ``<var>_coef`` entry is compared.

    Returns:
        ``(baseline - mediated) / baseline * 100``, or ``None`` when either
        dict lacks the coefficient or the baseline coefficient is
        (numerically) zero.
    """
    coef_key = f'{var}_coef'
    if not (coef_key in baseline and coef_key in mediated):
        return None

    before, after = baseline[coef_key], mediated[coef_key]
    # A near-zero baseline would make the ratio meaningless.
    if abs(before) < 1e-10:
        return None
    return (before - after) / before * 100


def get_r2(df, y_var, x_vars):
    """Return the R² of regressing ``y_var`` on ``x_vars`` without printing.

    Yields ``np.nan`` when ``run_gls`` declines to fit the specification.
    """
    fit = run_gls(df, y_var, x_vars, 'tmp', quiet=True)
    return np.nan if fit is None else fit['r_squared']


def _has_enough_data(df, column, min_obs=200):
    """Return whether *column* exists in *df* with more than *min_obs* non-null values."""
    return column in df.columns and df[column].notna().sum() > min_obs


def _attenuation_rows(path_label, baseline, mediated):
    """Build the mediation-summary rows for one path, echoing each to stdout.

    Returns an empty list when either regression was skipped (``None``).
    """
    if baseline is None or mediated is None:
        return []

    rows = []
    for zvar in DEMO_VARS:
        atten = compute_attenuation(baseline, mediated, zvar)
        if atten is None:
            continue
        rows.append({
            'path': path_label,
            'variable': zvar,
            'baseline': baseline[f'{zvar}_coef'],
            'mediated': mediated[f'{zvar}_coef'],
            'attenuation': atten,
        })
        print(f"  {zvar}: {baseline[f'{zvar}_coef']:.4f} → "
              f"{mediated[f'{zvar}_coef']:.4f} (atten: {atten:.1f}%)")
    return rows


def _shapley_r2(sample, y_var, controls, groups):
    """Shapley value of each variable group's contribution to R².

    ``controls`` enter every specification; ``groups`` maps a group name to
    its list of regressors. Returns ``(values, n_skipped)`` where ``values``
    maps each group name to its Shapley R² and ``n_skipped`` counts marginal
    contributions dropped because one of the two R² values was NaN.
    """
    names = list(groups)
    n_groups = len(names)
    r2_cache = {}

    def coalition_r2(members):
        # Each coalition is fitted once, with regressors in a fixed group
        # order, so the "with" and "without" sides always agree.
        key = frozenset(members)
        if key not in r2_cache:
            x_vars = list(controls)
            for name in names:
                if name in key:
                    x_vars.extend(groups[name])
            r2_cache[key] = get_r2(sample, y_var, x_vars)
        return r2_cache[key]

    values = {name: 0.0 for name in names}
    n_skipped = 0
    for name in names:
        others = [g for g in names if g != name]
        for size in range(len(others) + 1):
            # Shapley weight: s! * (n-s-1)! / n!
            weight = (math.factorial(size) * math.factorial(n_groups - size - 1)) / math.factorial(n_groups)
            for subset in combinations(others, size):
                r2_without = coalition_r2(subset)
                r2_with = coalition_r2(subset + (name,))
                if np.isnan(r2_with) or np.isnan(r2_without):
                    n_skipped += 1
                    continue
                values[name] += weight * (r2_with - r2_without)
    return values, n_skipped


def main():
    """Run Phase 6: mediation paths (Table 9) and Shapley R² shares (Table 10).

    Reads ``net_gross_panel.csv`` from ``DATA``, keeps years up to 2024,
    runs the Baron-Kenny style regressions for paths A, B and C, and writes
    ``mediation_paths.md`` and ``shapley_decomposition.md`` (UTF-8) to
    ``OUT_TABLES``. Progress and results are printed to stdout.
    """
    print("=" * 70)
    print("PHASE 6: VARIANCE DECOMPOSITION")
    print("=" * 70)

    df = pd.read_csv(DATA / "net_gross_panel.csv")
    df = df[df['year'] <= 2024].copy()
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    controls = [c for c in EBA_CONTROLS if _has_enough_data(df, c)]
    base_x = DEMO_VARS + controls

    # ══════════════════════════════════════════════════════════════════
    # TABLE 9: BARON-KENNY MEDIATION PATHS
    # ══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 50)
    print("TABLE 9: MEDIATION PATHS")
    print("=" * 50)

    mediation_summary = []

    # ── Path A: Z → S-I gap → CA ──────────────────────────────────
    print("\n--- Path A: Z → S-I gap → CA ---")
    si_var = 'savings_investment_gap'
    if _has_enough_data(df, si_var):
        run_gls(df, si_var, base_x, 'A1: Z → S-I gap')
        ca_baseline = run_gls(df, 'ca_gdp', base_x, 'A2: Z → CA (baseline)')
        ca_mediated = run_gls(df, 'ca_gdp', DEMO_VARS + [si_var] + controls,
                              'A3: Z + S-I → CA')
        mediation_summary.extend(
            _attenuation_rows('A: Z → S-I → CA', ca_baseline, ca_mediated))
    else:
        print("  savings_investment_gap not available")
        # Still report the baseline CA regression on the console.
        run_gls(df, 'ca_gdp', base_x, 'A2: Z → CA (baseline)')

    # ── Path B: Z → gross_ifi → income_balance ────────────────────
    print("\n--- Path B: Z → gross_ifi → income_balance ---")
    income_ok = _has_enough_data(df, 'income_balance_gdp')
    income_baseline = None
    if income_ok:
        has_ifi = 'gross_ifi' in df.columns
        if has_ifi:
            run_gls(df, 'gross_ifi', base_x, 'B1: Z → gross_ifi')
        else:
            print("  gross_ifi not available")
        income_baseline = run_gls(df, 'income_balance_gdp', base_x,
                                  'B2: Z → income_bal (baseline)')
        if has_ifi:
            income_ifi = run_gls(df, 'income_balance_gdp',
                                 DEMO_VARS + ['gross_ifi'] + controls,
                                 'B3: Z + gross_ifi → income_bal')
            mediation_summary.extend(
                _attenuation_rows('B: Z → gross_ifi → income_bal',
                                  income_baseline, income_ifi))

    # ── Path C: Z → debt_assets → income_balance ──────────────────
    print("\n--- Path C: Z → debt_assets → income_balance ---")
    if 'income_balance_gdp' in df.columns and _has_enough_data(df, 'debt_assets_gdp'):
        run_gls(df, 'debt_assets_gdp', base_x, 'C1: Z → debt_assets')
        if not income_ok:
            # Path B was skipped, so its baseline does not exist; fit it here
            # instead of referencing an unbound name.
            income_baseline = run_gls(df, 'income_balance_gdp', base_x,
                                      'C2: Z → income_bal (baseline)')
        income_debt = run_gls(df, 'income_balance_gdp',
                              DEMO_VARS + ['debt_assets_gdp'] + controls,
                              'C3: Z + debt_assets → income_bal')
        mediation_summary.extend(
            _attenuation_rows('C: Z → debt_assets → income_bal',
                              income_baseline, income_debt))

    # Write mediation summary
    lines = ["# Table 9: Mediation Summary\n"]
    lines.append("| Path | Variable | Baseline | Mediated | Attenuation % |")
    lines.append("|:---|:---|---:|---:|---:|")
    for m in mediation_summary:
        lines.append(f"| {m['path']} | {m['variable']} | "
                     f"{m['baseline']:.4f} | {m['mediated']:.4f} | {m['attenuation']:.1f}% |")
    lines.append("\n*Attenuation = (baseline - mediated) / baseline × 100.*")
    lines.append("*Baron & Kenny (1986) mediation framework with PanelGLS.*")

    # The table contains non-ASCII characters (→, ×); pin the encoding so the
    # write does not depend on the platform's locale default.
    (OUT_TABLES / "mediation_paths.md").write_text('\n'.join(lines), encoding='utf-8')
    print("\n  Saved: mediation_paths.md")

    # ══════════════════════════════════════════════════════════════════
    # TABLE 10: SHAPLEY-OWEN R² DECOMPOSITION
    # ══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 50)
    print("TABLE 10: SHAPLEY-OWEN R² DECOMPOSITION")
    print("=" * 50)

    # Variable groups
    group_Z = DEMO_VARS
    if 'savings_investment_gap' in df.columns:
        group_SI = ['savings_investment_gap']
    else:
        group_SI = ['gross_national_savings_gdp', 'gross_investment_gdp']
    group_SI = [v for v in group_SI if _has_enough_data(df, v)]
    group_gross = [v for v in ['gross_assets_gdp', 'gross_liab_gdp']
                   if _has_enough_data(df, v)]

    groups = {'Z': group_Z, 'S-I': group_SI, 'Gross': group_gross}
    group_names = [k for k, v in groups.items() if v]

    if len(group_names) >= 2:
        # Common sample: every subset is fitted on the same rows.
        all_vars = controls.copy()
        for g in group_names:
            all_vars.extend(groups[g])
        common = df[['ca_gdp'] + all_vars + ['iso3', 'year']].dropna()
        print(f"  Common sample: {len(common)} obs, {common['iso3'].nunique()} countries")

        shapley, n_skipped = _shapley_r2(
            common, 'ca_gdp', controls, {g: groups[g] for g in group_names})
        if n_skipped:
            print(f"  WARNING: {n_skipped} marginal contribution(s) skipped (R² unavailable)")

        # Normalize
        total_shapley = sum(shapley.values())
        print(f"\n  Shapley R² contributions (total = {total_shapley:.4f}):")

        lines_s = ["# Table 10: Shapley-Owen R² Decomposition\n"]
        lines_s.append("| Variable Group | Variables | Shapley R² | Share (%) |")
        lines_s.append("|:---|:---|---:|---:|")
        for g in group_names:
            share = shapley[g] / total_shapley * 100 if total_shapley > 0 else 0
            vars_str = ', '.join(groups[g])
            lines_s.append(f"| {g} | {vars_str} | {shapley[g]:.4f} | {share:.1f}% |")
            print(f"    {g:10s}  R²={shapley[g]:.4f}  ({share:.1f}%)")

        lines_s.append(f"| **Total** | | **{total_shapley:.4f}** | **100%** |")
        lines_s.append("\n*Shapley-Owen decomposition of R² for ca_gdp.*")
        lines_s.append(f"*Controls ({', '.join(controls)}) included in all subsets.*")
        lines_s.append(f"*Common sample: N={len(common)}, {common['iso3'].nunique()} countries.*")

        (OUT_TABLES / "shapley_decomposition.md").write_text('\n'.join(lines_s), encoding='utf-8')
        print("\n  Saved: shapley_decomposition.md")
    else:
        print("  Insufficient variable groups for Shapley decomposition")

    print("\n" + "=" * 70)
    print("PHASE 6 COMPLETE")
    print("=" * 70)


if __name__ == '__main__':
    main()
